/* Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 * ===================================================================================================================*/

#include "base/type/types_impl.h"
#include <cmath>
#include <utility>
#include "common/ge_common/debug/ge_log.h"
#include "graph/utils/type_utils.h"

namespace ge {

/**
 * @brief Map a Format enumerator to its printable name.
 * @param format format value to look up.
 * @return Pointer to a static NUL-terminated string. Values outside [0, FORMAT_END), and table
 *         slots without a name, yield "UNKNOWN"; the result is never nullptr.
 */
const char_t *TypeImpl::GetFormatName(Format format) {
  // Indexed by enumerator value; reserved / pseudo formats are spelled "UNKNOWN".
  static const char_t *const names[FORMAT_END] = {
      "NCHW",
      "NHWC",
      "ND",
      "NC1HWC0",
      "FRACTAL_Z",
      "NC1C0HWPAD",  // 5
      "NHWC1C0",
      "FSR_NCHW",
      "FRACTAL_DECONV",
      "C1HWNC0",
      "FRACTAL_DECONV_TRANSPOSE",  // 10
      "FRACTAL_DECONV_SP_STRIDE_TRANS",
      "NC1HWC0_C04",
      "FRACTAL_Z_C04",
      "CHWN",
      "DECONV_SP_STRIDE8_TRANS",  // 15
      "HWCN",
      "NC1KHKWHWC0",
      "BN_WEIGHT",
      "FILTER_HWCK",
      "LOOKUP_LOOKUPS",  // 20
      "LOOKUP_KEYS",
      "LOOKUP_VALUE",
      "LOOKUP_OUTPUT",
      "LOOKUP_HITS",
      "C1HWNCoC0",  // 25
      "MD",
      "NDHWC",
      "UNKNOWN",  // FORMAT_FRACTAL_ZZ
      "FRACTAL_NZ",
      "NCDHW",  // 30
      "DHWCN",
      "NDC1HWC0",
      "FRACTAL_Z_3D",
      "CN",
      "NC",  // 35
      "DHWNC",
      "FRACTAL_Z_3D_TRANSPOSE",
      "FRACTAL_ZN_LSTM",
      "FRACTAL_Z_G",
      "UNKNOWN",  // 40, FORMAT_RESERVED
      "UNKNOWN",  // FORMAT_ALL
      "UNKNOWN",  // FORMAT_NULL
      "ND_RNN_BIAS",
      "FRACTAL_ZN_RNN",
      "NYUV",  // 45
      "NYUV_A",
      "NCL",
      "FRACTAL_Z_WINO",
      "C1HWC0",
      "FRACTAL_NZ_C0_16",
      "FRACTAL_NZ_C0_32",
  };
  // Compare as a wide signed integer so negative values (e.g. an int cast to Format) are
  // rejected instead of indexing before the start of the table.
  const int64_t index = static_cast<int64_t>(format);
  if ((index < 0) || (index >= static_cast<int64_t>(FORMAT_END))) {
    return "UNKNOWN";
  }
  // If FORMAT_END grows beyond the initializers above, the trailing slots are zero-initialized;
  // report those as unknown rather than handing nullptr to the caller.
  const char_t *const name = names[index];
  return (name != nullptr) ? name : "UNKNOWN";
}

/**
 * @brief Integer division of n1 by n2, rounded toward positive infinity.
 * @param n1 dividend.
 * @param n2 divisor.
 * @return ceil(n1 / n2) for any signs; 0 when n1 == 0 or n2 == 0. INT64_MIN / -1, whose exact
 *         result is not representable in int64_t, saturates to INT64_MAX.
 */
int64_t TypeImpl::CeilDiv(const int64_t n1, const int64_t n2) {
  if ((n1 == 0) || (n2 == 0)) {
    return 0;
  }
  // INT64_MIN / -1 (and INT64_MIN % -1) overflow int64_t, which is undefined behaviour.
  if ((n1 == INT64_MIN) && (n2 == -1)) {
    return INT64_MAX;
  }
  const int64_t quotient = n1 / n2;
  const int64_t remainder = n1 % n2;
  // Built-in division truncates toward zero. That already equals the ceiling when the exact
  // quotient is negative; it is one too small only when there is a remainder and the exact
  // quotient is positive, i.e. the remainder (which carries the dividend's sign) has the same
  // sign as the divisor.
  const bool round_up = (remainder != 0) && ((remainder > 0) == (n2 > 0));
  return round_up ? (quotient + 1) : quotient;
}

/**
 * @brief Check whether a * b is representable as int64_t, without performing the multiplication.
 * @param a first factor.
 * @param b second factor.
 * @return SUCCESS if the product fits in int64_t, FAILED if it would overflow.
 */
Status TypeImpl::CheckInt64MulOverflow(const int64_t a, const int64_t b) {
  // A zero factor yields 0, which can never overflow.
  if ((a == 0) || (b == 0)) {
    return SUCCESS;
  }
  // Each test compares one factor against the limit divided by the other. Integer division
  // truncates toward zero, which yields the exact bound for every sign combination.
  bool overflows = false;
  if (a > 0) {
    overflows = (b > 0) ? (a > (INT64_MAX / b)) : (b < (INT64_MIN / a));
  } else {
    overflows = (b > 0) ? (a < (INT64_MIN / b)) : (b < (INT64_MAX / a));
  }
  return overflows ? FAILED : SUCCESS;
}

/**
 * @brief Compute the storage size in bytes of element_count elements of data_type.
 * @param element_count number of elements; must be non-negative.
 * @param data_type element type, resolved through TypeUtils::GetDataTypeLength.
 * @return The size in bytes, or -1 when element_count is negative, the data type is not
 *         supported, or the element_count * length product would overflow int64_t.
 */
int64_t TypeImpl::GetSizeInBytes(int64_t element_count, DataType data_type) {
  if (element_count < 0) {
    GELOGW("[Check][param]GetSizeInBytes failed, element_count:%" PRId64 " less than 0.", element_count);
    return -1;
  }
  uint32_t type_size = 0U;
  if (!TypeUtils::GetDataTypeLength(data_type, type_size)) {
    GELOGW("[Check][DataType]GetSizeInBytes failed, data_type:%d not support.", static_cast<int32_t>(data_type));
    return -1;
  }
  if (type_size > static_cast<uint32_t>(kDataTypeSizeBitOffset)) {
    // Lengths above kDataTypeSizeBitOffset are handled as "offset + width in bits" (presumably
    // sub-byte types; confirm against TypeUtils::GetDataTypeLength). Multiply by the bit width,
    // then round the total bit count up to whole bytes.
    const uint32_t bit_size = type_size - static_cast<uint32_t>(kDataTypeSizeBitOffset);
    if (CheckInt64MulOverflow(element_count, static_cast<int64_t>(bit_size)) == FAILED) {
      // bit_size is uint32_t, so it must be printed with PRIu32 (not %d).
      GELOGW("[Check][overflow]GetSizeInBytes failed, when multiplying %" PRId64 " and %" PRIu32 ".", element_count,
             bit_size);
      return -1;
    }
    return CeilDiv(element_count * static_cast<int64_t>(bit_size), static_cast<int64_t>(kBitNumOfOneByte));
  }
  // Otherwise the length is used directly as the per-element size in bytes.
  if (CheckInt64MulOverflow(element_count, static_cast<int64_t>(type_size)) == FAILED) {
    // type_size is uint32_t, so it must be printed with PRIu32 (not PRId32).
    GELOGW("[Check][overflow]GetSizeInBytes failed, when multiplying %" PRId64 " and %" PRIu32 ".", element_count,
           type_size);
    return -1;
  }
  return element_count * static_cast<int64_t>(type_size);
}

/**
 * @brief Initialise obj with an owned copy of the given symbol strings.
 * @param obj Promote whose storage is (re)created.
 * @param syms symbol names; a nullptr entry is stored as an empty string.
 * If the non-throwing allocation fails, obj is left holding empty (null) storage.
 */
void PromoteImpl::Construct(Promote &obj, const std::initializer_list<const char *> &syms) {
  // data_ is type-erased (shared_ptr<void>), so the deleter restores the concrete type before deleting.
  const auto deleter = [](void * const ptr) { delete static_cast<std::vector<std::string> *>(ptr); };
  std::vector<std::string> *const storage = new (std::nothrow) std::vector<std::string>();
  obj.data_ = std::shared_ptr<void>(storage, deleter);
  if (storage == nullptr) {
    return;
  }
  for (const char *const sym : syms) {
    const char *const text = (sym != nullptr) ? sym : "";
    (void) storage->emplace_back(text);
  }
}

/**
 * @brief List the symbols stored in obj as C strings.
 * @param obj Promote to read.
 * @return One pointer per stored symbol, in insertion order, or an empty vector if obj has no
 *         storage. The pointers refer to strings owned by obj's storage and stay valid only while
 *         that storage is alive and unmodified.
 */
std::vector<const char *> PromoteImpl::Syms(const Promote &obj) {
  const auto *const storage = static_cast<const std::vector<std::string> *>(obj.data_.get());
  if (storage == nullptr) {
    return {};
  }
  std::vector<const char *> views;
  views.reserve(storage->size());
  for (std::size_t i = 0U; i < storage->size(); ++i) {
    views.push_back((*storage)[i].c_str());
  }
  return views;
}

/**
 * @brief Transfer other's storage into obj; other is left with empty (null) storage.
 */
void PromoteImpl::MoveConstruct(Promote &obj, Promote &&other) noexcept {
  obj.data_ = std::exchange(other.data_, nullptr);
}

/**
 * @brief Move-assign other's storage into obj.
 * @return obj. Self-assignment leaves obj unchanged; otherwise obj's previous storage reference
 *         is released and other is left with empty (null) storage.
 */
Promote &PromoteImpl::MoveAssign(Promote &obj, Promote &&other) noexcept {
  if (&obj == &other) {
    return obj;
  }
  // Take other's storage into a temporary, then swap it in; obj's old storage is released
  // when the temporary is destroyed.
  std::shared_ptr<void>(std::move(other.data_)).swap(obj.data_);
  return obj;
}

/**
 * @brief Public entry point returning the printable name of format.
 * Forwards to TypeImpl::GetFormatName; see it for how unrecognised values are reported.
 */
const char_t *GetFormatName(Format format) {
  const char_t *const format_name = TypeImpl::GetFormatName(format);
  return format_name;
}

/**
 * @brief Public entry point computing the byte size of element_count elements of data_type.
 * Forwards to TypeImpl::GetSizeInBytes, which returns -1 on invalid input or overflow.
 */
int64_t GetSizeInBytes(int64_t element_count, DataType data_type) {
  const int64_t size_in_bytes = TypeImpl::GetSizeInBytes(element_count, data_type);
  return size_in_bytes;
}

std::vector<const char *> Promote::Syms() const {
  return PromoteImpl::Syms(*this);
}

/**
 * @brief Build a Promote holding copies of syms (delegates to PromoteImpl::Construct).
 */
Promote::Promote(const std::initializer_list<const char *> &syms) {
  Promote &self = *this;
  PromoteImpl::Construct(self, syms);
}

/**
 * @brief Move constructor: takes over other's storage (delegates to PromoteImpl::MoveConstruct).
 */
Promote::Promote(Promote &&other) noexcept {
  PromoteImpl::MoveConstruct(*this, static_cast<Promote &&>(other));
}

/**
 * @brief Move assignment: takes over other's storage (delegates to PromoteImpl::MoveAssign).
 * @return *this.
 */
Promote &Promote::operator=(Promote &&other) noexcept {
  Promote &assigned = PromoteImpl::MoveAssign(*this, static_cast<Promote &&>(other));
  return assigned;
}
}  // namespace ge
